#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
from tensorflow import keras
import os
import gzip
import matplotlib.pyplot as plt
import DonaldDuckFunc
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import RandomFlip, RandomZoom, RandomRotation


class DonaldDuckDataset():
    """Common preprocessing and normalization helpers for the image datasets.

    :meth:`preprocess` reads ``x_train``, ``y_train``, ``x_test``, ``y_test``,
    ``img_rows``, ``img_cols``, ``channels``, ``num_classes`` and
    ``standardization`` from the instance, so subclasses must set those first.
    """

    def preprocess(self):
        """Reshape, rescale, optionally standardize and one-hot encode in place.

        Side effects on the instance:
          * ``x_train`` / ``x_test`` are reshaped to the Keras backend's image
            data format, cast to float32 and divided by 255.
          * ``input_shape`` is set to the per-sample shape actually applied.
          * ``data_mean`` (mean over axis 0, i.e. per pixel/channel position)
            and ``data_std`` (one scalar std over every training value) are
            always computed, even when ``standardization`` is False, so that
            :meth:`standardize` / :meth:`unstandardize` remain usable.
          * ``y_train`` / ``y_test`` are one-hot encoded with
            ``keras.utils.to_categorical`` using ``num_classes`` classes.
        """
        if keras.backend.image_data_format() == 'channels_first':
            input_shape = (self.channels, self.img_rows, self.img_cols)
        else:
            input_shape = (self.img_rows, self.img_cols, self.channels)
        # Subclasses initialise input_shape as channels-last unconditionally;
        # record the layout really applied so channels_first models match.
        self.input_shape = input_shape
        self.x_train = self.x_train.reshape((self.x_train.shape[0],) + input_shape)
        self.x_test = self.x_test.reshape((self.x_test.shape[0],) + input_shape)

        self.x_train = self.x_train.astype('float32') / 255
        self.x_test = self.x_test.astype('float32') / 255
        # Statistics come from the training split only and are reused for test.
        self.data_mean = np.mean(self.x_train, axis=0)
        self.data_std = np.std(self.x_train)
        if self.standardization:
            self.x_train = self.standardize(self.x_train)
            self.x_test = self.standardize(self.x_test)
        self.y_train = keras.utils.to_categorical(self.y_train, self.num_classes)
        self.y_test = keras.utils.to_categorical(self.y_test, self.num_classes)

    def getData(self):
        """Return ``((x_train, y_train), (x_test, y_test))``."""
        return (self.x_train, self.y_train), (self.x_test, self.y_test)

    def standardize(self, x):
        """Return ``(x - data_mean) / data_std`` using the stored training statistics."""
        return (x - self.data_mean) / self.data_std

    def unstandardize(self, x):
        """Invert :meth:`standardize`: return ``x * data_std + data_mean``."""
        return x * self.data_std + self.data_mean

    def clip(self, x, lower_bound=0, upper_bound=1):
        """Return ``x`` with every element limited to ``[lower_bound, upper_bound]``.

        Args:
            x: array-like input.
            lower_bound: values below this are replaced by it.
            upper_bound: values above this are replaced by it.
        """
        return np.clip(x, lower_bound, upper_bound)

    def restore(self, x):
        """Project ``x`` back onto the valid pixel range [0, 1].

        When standardization is enabled, ``x`` is first mapped back to pixel
        space, clipped, and then re-standardized, so the result lives in the
        same space as the input.
        """
        if self.standardization:
            x = self.unstandardize(x)
        x = self.clip(x)
        if self.standardization:
            x = self.standardize(x)
        return x

    def data_augmentation(self):
        """Grow the training set 8x with flipped, rotated and zoomed copies.

        Three rounds run cumulatively. Each round appends an augmented copy of
        the *current* training set, so later rounds also transform the copies
        made by earlier ones. Each round also duplicates the labels to match.
        Only ``x_train`` / ``y_train`` are modified; the test split is untouched.
        """
        augmenters = (
            RandomFlip("horizontal_and_vertical"),
            RandomRotation(0.2),
            # height_factor is (lower, upper); both negative requests a zoom-in.
            RandomZoom(height_factor=(-0.3, -0.2)),
        )
        for layer in augmenters:
            # These layers are the identity in inference mode, so ask for the
            # random transform explicitly instead of relying on the default.
            augmented = layer(self.x_train, training=True).numpy()
            # Concatenate along axis 0 (not vstack): a 1-D label vector must be
            # extended end-to-end, not stacked into a (2, N) matrix.
            self.x_train = np.concatenate((self.x_train, augmented), axis=0)
            self.y_train = np.concatenate((self.y_train, self.y_train), axis=0)


class MNIST(DonaldDuckDataset):
    """MNIST handwritten digits: 28x28 single-channel images in 10 classes.

    The data comes from ``keras.datasets.mnist.load_data`` and is run through
    ``preprocess`` as soon as the object is constructed.
    """

    def __init__(self, standardization=False):
        """Load and preprocess MNIST.

        Args:
            standardization: stored on the instance; when True, ``preprocess``
                standardizes the images with training-split statistics.
        """
        self.name = 'mnist'
        self.standardization = standardization
        self.num_classes = 10
        self.channels = 1
        self.img_rows = self.img_cols = 28
        self.input_shape = (self.img_rows, self.img_cols, self.channels)
        train_split, test_split = keras.datasets.mnist.load_data()
        self.x_train, self.y_train = train_split
        self.x_test, self.y_test = test_split
        self.preprocess()


class CIFAR10(DonaldDuckDataset):
    """CIFAR-10: 32x32 three-channel images in 10 classes.

    The data comes from ``keras.datasets.cifar10.load_data``. It can optionally
    be augmented, and is then run through ``preprocess`` at construction time.
    """

    def __init__(self, standardization=False, data_aug_flag=False):
        """Load CIFAR-10, optionally augment the training split, then preprocess.

        Args:
            standardization: stored on the instance; when True, ``preprocess``
                standardizes the images with training-split statistics.
            data_aug_flag: when True, ``data_augmentation`` runs on the raw
                training data *before* ``preprocess``, so the normalization
                statistics also cover the augmented samples.
        """
        self.name = 'cifar10'
        self.standardization = standardization
        self.num_classes = 10
        self.channels = 3
        self.img_rows = self.img_cols = 32
        self.input_shape = (self.img_rows, self.img_cols, self.channels)
        train_split, test_split = keras.datasets.cifar10.load_data()
        self.x_train, self.y_train = train_split
        self.x_test, self.y_test = test_split
        if data_aug_flag:
            self.data_augmentation()
        self.preprocess()


class Fashion(DonaldDuckDataset):
    """Fashion-MNIST: 28x28 single-channel images in 10 classes.

    Unlike the other datasets here, the data is read from four local
    gzip-compressed IDX files under ``dataPath``.
    """

    def __init__(self, standardization=False, dataPath=r'data//fashion'):
        """Load Fashion-MNIST from ``dataPath`` and preprocess it.

        Args:
            standardization: stored on the instance; when True, ``preprocess``
                standardizes the images with training-split statistics.
            dataPath: directory containing the ``train-*`` and ``t10k-*``
                ``.gz`` IDX files.
        """
        self.standardization = standardization
        self.img_rows, self.img_cols = 28, 28
        self.channels = 1
        self.num_classes = 10
        self.input_shape = (self.img_rows, self.img_cols, self.channels)
        self.dataPath = dataPath
        (self.x_train, self.y_train), (self.x_test, self.y_test) = self.load_fashion()
        self.preprocess()
        self.name = 'fashion'

    def load_fashion(self):
        """Read the train and test splits from ``self.dataPath``.

        Returns:
            ``((x_train, y_train), (x_test, y_test))``. Images are uint8 arrays
            of shape ``(N, img_rows, img_cols, channels)``; labels are uint8
            arrays of shape ``(N,)``.

        Raises:
            OSError: if a file is missing or cannot be read as gzip.
            ValueError: if an image file does not hold exactly one
                ``img_rows x img_cols x channels`` image per label.
        """
        train_split = self._load_split('train')
        test_split = self._load_split('t10k')
        return train_split, test_split

    def _load_split(self, prefix):
        """Return ``(images, labels)`` for the split whose files start with ``prefix``."""
        labels = self._read_idx_payload(
            os.path.join(self.dataPath, prefix + '-labels-idx1-ubyte.gz'), header_bytes=8)
        images = self._read_idx_payload(
            os.path.join(self.dataPath, prefix + '-images-idx3-ubyte.gz'), header_bytes=16)
        images = images.reshape(len(labels), self.img_rows, self.img_cols, self.channels)
        return images, labels

    @staticmethod
    def _read_idx_payload(path, header_bytes):
        """Return the uint8 bytes of a gzip IDX file after skipping its header."""
        with gzip.open(path, 'rb') as stream:
            return np.frombuffer(stream.read(), dtype=np.uint8, offset=header_bytes)


if __name__ == '__main__':
    # No demo entry point: running the module directly only prints a blank line.
    print('')
